import matplotlib.pyplot as plt
import matplotlib as mpl
import ipywidgets as widgets
from ipywidgets import interactive_output, HBox, VBox
import numpy as np

# Style sûr et dispo partout
plt.style.use('ggplot')
mpl.rcParams['axes.edgecolor'] = 'black'
mpl.rcParams['axes.linewidth'] = 1.2

# Fonction pour calculer les quantités à un avancement donné
def quantites(n1, n2, x):
    return {
        'Cu²⁺': max(n1 - x, 0),
        'OH⁻': max(n2 - 2 * x, 0),
        'Cu(OH)₂': x,
    }

# Couleurs des espèces chimiques
couleurs = {
    'Cu²⁺': '#00bcd4',    # cyan léger
    'OH⁻': '#4caf50',     # vert
    'Cu(OH)₂': '#f48fb1', # rose clair
}

# Fonction principale d'affichage
def plot_state(n1, n2, x):
    plt.figure(figsize=(9, 5))

    x_max = min(n1, n2 / 2)
    x_clipped = min(x, x_max)  # Utiliser x limité pour éviter les bugs

    etat = quantites(n1, n2, x_clipped)
    species = ['Cu²⁺', 'OH⁻', 'Cu(OH)₂']
    positions = np.arange(len(species))
    quantites_initiales = [n1, n2, 0]

    # Informations texte au-dessus
    plt.text(
        1, 1.15 * max(n1, n2),
        f"n₀(Cu²⁺) = {n1:.2f} mmol   |   n₀(OH⁻) = {n2:.2f} mmol   |   x = {x:.2f} mmol",
        ha='center', fontsize=13, weight='bold'
    )

    # 1. Barres initiales (grises, transparentes)
    plt.bar(
        positions,
        quantites_initiales,
        width=0.5,
        color=['lightgray'] * 3,
        edgecolor='black',
        alpha=0.5,
        label='Quantité initiale'
    )

    # 2. Partie disparue (colorée plus claire)
    y = [etat[s] for s in species]
    y_disappear = [quantites_initiales[i] - y[i] for i in range(len(species))]
    plt.bar(
        positions,
        y_disappear,
        width=0.5,
        bottom=y,
        color=['#d0f0fd', '#b2fab4', 'lightgray'],
        edgecolor='black',
        label='Partie disparue'
    )

    # 3. Barres finales (couleurs principales)
    plt.bar(
        positions,
        y,
        width=0.5,
        color=[couleurs[s] for s in species],
        edgecolor='black',
        label='Quantité finale'
    )

    # Flèches Δn avec noms d'espèces
    for pos, s, n_init in zip(positions, species, quantites_initiales):
        n_final = etat[s]
        delta = n_final - n_init
        signe = '+' if delta > 0 else '−'

        x_center = pos
        y_start = max(n_init, n_final)
        y_end = min(n_init, n_final)
        y_middle = (n_init + n_final) / 2

        if n_init != n_final:
            plt.annotate(
                '',
                xy=(x_center, y_end),
                xytext=(x_center, y_start),
                arrowprops=dict(arrowstyle='<->', color='black', lw=1.5)
            )
            plt.text(
                x_center + 0.15,
                y_middle,
                f"Δn({s}) = {signe}{abs(delta):.2f} mmol",
                ha='left',
                va='center',
                fontsize=10
            )

    # Ajouter la ligne avec les quantités de chaque espèce
    quantites_actuelles = [etat['Cu²⁺'], etat['OH⁻'], etat['Cu(OH)₂']]
    plt.figtext(
        0.5, -0.12,
        f"Quantités actuelles : n(Cu²⁺) = {quantites_actuelles[0]:.2f} mmol, "
        f"n(OH⁻) = {quantites_actuelles[1]:.2f} mmol, "
        f"n(Cu(OH)₂) = {quantites_actuelles[2]:.2f} mmol",
        ha='center', fontsize=12, weight='bold'
    )

    # Ajouter xmax et xf
    if x >= x_max:
        plt.figtext(
            0.5, -0.16,
            f"xmax = xf = {x_max:.2f} mmol" ,
            ha='center', fontsize=12, color='green', weight='bold'
        )

    plt.xticks(positions, species, fontsize=12)
    plt.ylabel("Quantité de matière (mmol)", fontsize=12)
    plt.title("Évolution des quantités de matière", fontsize=14, weight='bold')
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    plt.ylim(0, 1.3 * max(n1, n2))

    # Message d'erreur si besoin
    if x > x_max:
        plt.figtext(
            0.5, -0.08,
            "Avancement trop grand ! La réaction est terminée.",
            ha='center',
            fontsize=13,
            color='red',
            weight='bold'
        )

    plt.tight_layout()
    plt.show()
    plt.close()

# --- Interface ipywidgets ---
n1_slider = widgets.FloatSlider(
    value=5.0,
    min=0,
    max=10,
    step=0.1,
    description='n₀(Cu²⁺)',
    continuous_update=False,
    style={'description_width': '80px'}
)
n2_slider = widgets.FloatSlider(
    value=5.0,
    min=0,
    max=10,
    step=0.1,
    description='n₀(OH⁻)',
    continuous_update=False,
    style={'description_width': '80px'}
)

x_slider = widgets.FloatSlider(
    value=0,
    min=0,
    max=3,
    step=0.1,
    description='x',
    continuous_update=True,
    style={'description_width': '30px'}
)

# Mise à jour de x quand n1 ou n2 change
def update_x_slider(n1, n2):
    min_val = min(n1, n2 / 2)
    x_slider.max = 1.2 * min_val
    x_slider.value = 0

def on_change(change):
    update_x_slider(n1_slider.value, n2_slider.value)

n1_slider.observe(on_change, names='value')
n2_slider.observe(on_change, names='value')

# Assemblage de l'interface
controls = HBox([n1_slider, n2_slider, x_slider])
out = interactive_output(plot_state, {'n1': n1_slider, 'n2': n2_slider, 'x': x_slider})

display(VBox([controls, out]))
